# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os

import torch
from torch import distributed as dist
from torch.utils.data import Sampler

import nncf
from examples.torch.common.example_logger import logger


def configure_distributed(config):
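    """Initializes the default process group for distributed training.

    Resolves the global rank (from the RANK environment variable when the
    "env://" init method is used, or from the node rank and the local GPU index
    when launched as one process per GPU), binds the process to its GPU, and
    writes the actual world size reported by torch.distributed back into the
    config.
    """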
    if config.dist_url == "env://" and config.rank == -1:
        config.rank = int(os.environ["RANK"])
    config.ngpus_per_node = torch.cuda.device_count()

    if config.current_gpu is not None:
        # Distributed multiprocessing: derive the global rank from the node rank and the local GPU index
        config.rank = config.rank * config.ngpus_per_node + config.current_gpu
        # Must be called before any CUDA kernels are executed to prevent failures in kernels that allocate
        # memory on the default device (e.g. the NMS kernel - https://github.com/facebookresearch/maskrcnn-benchmark/issues/74)
        torch.cuda.set_device(config.current_gpu)

    logger.info("| distributed init (rank {}): {}".format(config.rank, config.dist_url))
    dist.init_process_group(
        backend=config.dist_backend, init_method=config.dist_url, world_size=config.world_size, rank=config.rank
    )
    config.world_size = dist.get_world_size()
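
# A minimal usage sketch (kept commented out so it is not executed on import): in a typical
# single-node, one-process-per-GPU launch, each spawned worker calls configure_distributed()
# with its local GPU index before building the model. The mp.spawn wiring and the
# SimpleNamespace config below are illustrative assumptions about the caller; they simply
# stand in for a config object exposing the same attribute names this function reads.
#
#   from types import SimpleNamespace
#   import torch.multiprocessing as mp
#
#   def worker(local_gpu, config):
#       config.current_gpu = local_gpu
#       configure_distributed(config)
#       # ... build the model, wrap it in DistributedDataParallel, run training ...
#
#   config = SimpleNamespace(dist_url="tcp://127.0.0.1:23456", dist_backend="nccl",
#                            rank=0, world_size=torch.cuda.device_count(), current_gpu=None)
#   mp.spawn(worker, args=(config,), nprocs=torch.cuda.device_count())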


class DistributedSampler(Sampler):
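    """Deterministically shards a dataset across processes.

    The index range ``range(len(dataset))`` is split into contiguous chunks of
    equal size, one per rank; ranks whose chunk comes up short are padded with
    index 0 so that every rank yields the same number of samples. Unlike
    ``torch.utils.data.distributed.DistributedSampler``, the indices are never
    shuffled.
    """
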
    def __init__(self, dataset, rank=None, world_size=None):
        super().__init__(dataset)
        if world_size is None:
            if not dist.is_available():
                raise nncf.ValidationError("Requires distributed package to be available")
            world_size = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise nncf.ValidationError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.world_size = world_size
        self.rank = rank
        indices = list(range(len(dataset)))
        self.samples_per_rank = (len(indices) - 1) // self.world_size + 1
        self.indices = indices[self.rank * self.samples_per_rank : (self.rank + 1) * self.samples_per_rank]

        if len(self.indices) < self.samples_per_rank:
            # Workaround for mock datasets with a small number of entries
            pad = [0] * (self.samples_per_rank - len(self.indices))
            self.indices += pad

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)
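
# A minimal usage sketch (kept commented out so it is not executed on import): the sampler is
# handed to a DataLoader so that each process only iterates over its own shard. The dataset,
# batch size, and worker count below are placeholders for illustration.
#
#   from torch.utils.data import DataLoader
#
#   sampler = DistributedSampler(train_dataset, rank=dist.get_rank(), world_size=dist.get_world_size())
#   loader = DataLoader(train_dataset, batch_size=64, sampler=sampler, num_workers=4, pin_memory=True)
#   for images, targets in loader:
#       ...  # each rank sees a disjoint, equally sized slice of the dataset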
